{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "slideshow": {
     "slide_type": "notes"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from IPython.core.interactiveshell import InteractiveShell\n",
    "InteractiveShell.ast_node_interactivity = 'all'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Reading and Writing Files"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
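   },
   "source": [
    "- We want to save trained models so that they can be used later in various settings (for example, to make predictions in deployment)\n",
    "- When running a long training process, the best practice is to save intermediate results periodically"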
   ]
  },
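  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "One way to save intermediate results periodically is to write a checkpoint every few epochs. The sketch below is only an illustration under assumptions that are not in the original: the tiny `nn.Linear` model, the `SGD` optimizer, the random data, the `save_every` interval, and the `checkpoint.pt` file name are all hypothetical.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Hypothetical toy setup: a linear model trained on random data\n",
    "model = nn.Linear(4, 1)\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n",
    "X, y = torch.randn(8, 4), torch.randn(8, 1)\n",
    "\n",
    "save_every = 2  # hypothetical interval: checkpoint every 2 epochs\n",
    "ckpt_path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pt')\n",
    "\n",
    "for epoch in range(1, 5):\n",
    "    loss = nn.functional.mse_loss(model(X), y)\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    if epoch % save_every == 0:\n",
    "        # Store what is needed to resume: epoch, model state, optimizer state\n",
    "        torch.save({'epoch': epoch,\n",
    "                    'model': model.state_dict(),\n",
    "                    'optimizer': optimizer.state_dict()}, ckpt_path)\n",
    "\n",
    "# Resume later: rebuild the objects first, then restore their state\n",
    "checkpoint = torch.load(ckpt_path)\n",
    "model.load_state_dict(checkpoint['model'])\n",
    "optimizer.load_state_dict(checkpoint['optimizer'])\n",
    "start_epoch = checkpoint['epoch'] + 1\n",
    "```\n",
    "\n",
    "Saving the optimizer state alongside the model state is a design choice that lets training continue where it stopped, rather than restarting the optimizer from scratch."
   ]
  },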
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "## Loading and Saving Tensors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0,
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "- For a single tensor, call `torch.save` and `torch.load` directly to write and read it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T16:57:14.320840Z",
     "iopub.status.busy": "2022-12-07T16:57:14.320234Z",
     "iopub.status.idle": "2022-12-07T16:57:15.480792Z",
     "shell.execute_reply": "2022-12-07T16:57:15.479939Z"
    },
    "origin_pos": 2,
    "slideshow": {
     "slide_type": "fragment"
    },
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "x = torch.arange(4)\n",
    "torch.save(x, 'x-file')  # write the tensor to a file named x-file"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 5,
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "- Read the data stored in the file back into memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T16:57:15.484795Z",
     "iopub.status.busy": "2022-12-07T16:57:15.484187Z",
     "iopub.status.idle": "2022-12-07T16:57:15.496200Z",
     "shell.execute_reply": "2022-12-07T16:57:15.495423Z"
    },
    "origin_pos": 7,
    "slideshow": {
     "slide_type": "fragment"
    },
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1, 2, 3])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x2 = torch.load('x-file')\n",
    "x2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 10,
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "- Store a **list of tensors**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T16:57:15.499724Z",
     "iopub.status.busy": "2022-12-07T16:57:15.499071Z",
     "iopub.status.idle": "2022-12-07T16:57:15.509309Z",
     "shell.execute_reply": "2022-12-07T16:57:15.508580Z"
    },
    "origin_pos": 12,
    "slideshow": {
     "slide_type": "fragment"
    },
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "y = torch.zeros(4)\n",
    "torch.save([x, y],'x-files')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "- Read the list of tensors back into memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x2, y2 = torch.load('x-files')\n",
    "(x2, y2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 15,
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "- Store a dictionary that maps strings to tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T16:57:15.512722Z",
     "iopub.status.busy": "2022-12-07T16:57:15.512205Z",
     "iopub.status.idle": "2022-12-07T16:57:15.519128Z",
     "shell.execute_reply": "2022-12-07T16:57:15.518294Z"
    },
    "origin_pos": 17,
    "slideshow": {
     "slide_type": "fragment"
    },
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "mydict = {'x': x, 'y': y}\n",
    "torch.save(mydict, 'mydict')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "- Read the dictionary of tensors back into memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mydict2 = torch.load('mydict')\n",
    "mydict2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## Loading and Saving Model Parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 20,
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "- Deep learning frameworks provide built-in functions to save and load an entire network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "- Note that this saves the model's **parameters**, **not the entire model**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "- To restore a model, build the architecture in code and then load the parameters from disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T16:57:15.522400Z",
     "iopub.status.busy": "2022-12-07T16:57:15.521895Z",
     "iopub.status.idle": "2022-12-07T16:57:15.528483Z",
     "shell.execute_reply": "2022-12-07T16:57:15.527720Z"
    },
    "origin_pos": 22,
    "slideshow": {
     "slide_type": "fragment"
    },
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output Y:\n",
      "tensor([[-0.0802,  0.2133, -0.1663,  0.1090, -0.1764, -0.0980, -0.0377,  0.0478,\n",
      "          0.3004,  0.0946],\n",
      "        [-0.1499,  0.2987,  0.1000, -0.0628, -0.1497, -0.2791, -0.1915, -0.0105,\n",
      "          0.1062,  0.3925]], grad_fn=<AddmmBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# Define a multilayer perceptron with one hidden layer\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden = nn.Linear(20, 256)\n",
    "        self.output = nn.Linear(256, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.output(F.relu(self.hidden(x)))\n",
    "\n",
    "net = MLP()  # instantiate the custom MLP\n",
    "X = torch.randn(size=(2, 20))  # input data\n",
    "Y = net(X)\n",
    "\n",
    "print(f'Output Y:\\n{Y}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 25,
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "- Store the model parameters in a file named `mlp.params`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T16:57:15.531616Z",
     "iopub.status.busy": "2022-12-07T16:57:15.531209Z",
     "iopub.status.idle": "2022-12-07T16:57:15.535898Z",
     "shell.execute_reply": "2022-12-07T16:57:15.535152Z"
    },
    "origin_pos": 27,
    "slideshow": {
     "slide_type": "fragment"
    },
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "torch.save(net.state_dict(), 'mlp.params')\n",
    "\n",
    "# Note: net.state_dict() returns the model's parameters (and any persistent buffers), keyed by name"
   ]
  },
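  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "To see what is actually written to `mlp.params`, it can help to inspect a `state_dict`. The sketch below repeats the `MLP` definition so that it runs on its own; the `shapes` dictionary is a helper introduced only for this illustration.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden = nn.Linear(20, 256)\n",
    "        self.output = nn.Linear(256, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.output(F.relu(self.hidden(x)))\n",
    "\n",
    "net = MLP()\n",
    "# state_dict() maps each parameter name to its tensor\n",
    "shapes = {name: tuple(t.shape) for name, t in net.state_dict().items()}\n",
    "for name, shape in shapes.items():\n",
    "    print(name, shape)\n",
    "```\n",
    "\n",
    "The keys are built from the attribute names (`hidden`, `output`), which is why the same class definition is needed again when the parameters are loaded."
   ]
  },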
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 30,
    "slideshow": {
     "slide_type": "fragment"
    }
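   },
   "source": [
    "- To restore the model, **instantiate the original MLP**\n",
    "- Instead of keeping the randomly initialized parameters, **read the parameters stored in the file directly**"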
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T16:57:15.539128Z",
     "iopub.status.busy": "2022-12-07T16:57:15.538719Z",
     "iopub.status.idle": "2022-12-07T16:57:15.545525Z",
     "shell.execute_reply": "2022-12-07T16:57:15.544793Z"
    },
    "origin_pos": 32,
    "slideshow": {
     "slide_type": "fragment"
    },
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "MLP(\n",
       "  (hidden): Linear(in_features=20, out_features=256, bias=True)\n",
       "  (output): Linear(in_features=256, out_features=10, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
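   ],
   "source": [
    "clone = MLP()  # rebuild the model architecture\n",
    "clone.load_state_dict(torch.load('mlp.params'))  # load the saved parameters\n",
    "clone.eval()   # switch to evaluation mode; eval() returns the module, so the notebook also displays it"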
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "```python\n",
    "torch.nn.Module.load_state_dict(state_dict, strict=True)\n",
    "```\n",
    "\n",
    "- Copies the parameters and buffers stored in `state_dict` into the module\n",
    "- With `strict=True`, the keys in `state_dict` must exactly match the keys returned by the module's own `state_dict()`"
   ]
  },
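  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "The effect of `strict` can be seen with a deliberately mismatched dictionary. This is a minimal sketch under assumptions: the small `nn.Sequential` networks and the `'extra'` key are hypothetical and exist only to provoke a key mismatch.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "source = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))\n",
    "state = source.state_dict()\n",
    "del state['2.bias']              # simulate a file that lacks one entry\n",
    "state['extra'] = torch.zeros(1)  # and contains one unknown entry\n",
    "\n",
    "target = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))\n",
    "\n",
    "# strict=True (the default) raises RuntimeError on any key mismatch\n",
    "try:\n",
    "    target.load_state_dict(state)\n",
    "    raised = False\n",
    "except RuntimeError:\n",
    "    raised = True\n",
    "\n",
    "# strict=False loads the matching keys and reports the rest\n",
    "result = target.load_state_dict(state, strict=False)\n",
    "print(result.missing_keys)     # keys the module expects but the dict lacks\n",
    "print(result.unexpected_keys)  # keys in the dict that the module does not have\n",
    "```\n",
    "\n",
    "Checking the returned `missing_keys` and `unexpected_keys` is a reasonable habit whenever `strict=False` is used, since mismatches are otherwise silent."
   ]
  },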
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-12-07T16:57:15.549346Z",
     "iopub.status.busy": "2022-12-07T16:57:15.548468Z",
     "iopub.status.idle": "2022-12-07T16:57:15.555963Z",
     "shell.execute_reply": "2022-12-07T16:57:15.555173Z"
    },
    "origin_pos": 37,
    "slideshow": {
     "slide_type": "slide"
    },
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[True, True, True, True, True, True, True, True, True, True],\n",
       "        [True, True, True, True, True, True, True, True, True, True]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Verify the loaded parameters: given the same input X, both models should produce the same output\n",
    "\n",
    "Y_clone = clone(X)\n",
    "Y_clone == Y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## Loading and Saving Model Architecture + Parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
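   },
   "source": [
    "```python\n",
    "torch.save(model, path)\n",
    "```\n",
    "\n",
    "- Saves the model together with its parameters; the object is pickled, so the class definition must still be available when the file is loaded"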
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(net,'mlp')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "```python\n",
    "torch.load(path)\n",
    "```\n",
    "\n",
    "- Loads the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
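   },
   "outputs": [],
   "source": [
    "# Note: PyTorch 2.6+ defaults to weights_only=True; there, pass weights_only=False (trusted files only)\n",
    "net2 = torch.load('mlp')"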
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[True, True, True, True, True, True, True, True, True, True],\n",
       "        [True, True, True, True, True, True, True, True, True, True]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Verify the loaded model: given the same input X, net2 and net should produce the same output\n",
    "\n",
    "Y_2 = net2(X)\n",
    "Y_2 == Y"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "幻灯片",
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": true,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
